import json
import random
import numpy as np
import torch
import os
import argparse
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    # Same seed for every library so a run can be reproduced from one value.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


def build_tokenizer(model_name):
    """Load a SentenceTransformer and expose its tokenizer alongside it.

    Args:
        model_name: Name or path passed to ``SentenceTransformer``.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``tokenizer`` is the loaded
        model's ``tokenizer`` attribute.
    """
    encoder = SentenceTransformer(model_name)
    return encoder.tokenizer, encoder


def load_vocab(tokenizer):
    """Return every token id from 0 up to ``tokenizer.vocab_size`` (exclusive).

    Args:
        tokenizer: Any object exposing an integer ``vocab_size`` attribute.

    Returns:
        A list ``[0, 1, ..., vocab_size - 1]``.
    """
    id_count = tokenizer.vocab_size
    return [*range(id_count)]

def construct_faiss_index(vocab_ids, tokenizer, model, use_cache, cache_prefix):
    """Embed the vocabulary tokens and build an inner-product FAISS index.

    The embeddings are L2-normalised before being added to the index, so
    inner-product search ranks by cosine similarity. Row ``i`` of the index
    corresponds to ``vocab_tokens[i]``.

    Embeddings and tokens are cached in ``{cache_prefix}_vocab_embs.npy`` and
    ``{cache_prefix}_vocab_list.json``. Both files are (re)written whenever
    the embeddings are computed, including when ``use_cache`` is false.

    Args:
        vocab_ids: Token ids to embed; converted to token strings with
            ``tokenizer.convert_ids_to_tokens``.
        tokenizer: Tokenizer providing ``convert_ids_to_tokens`` and
            ``vocab_size``.
        model: Encoder exposing ``encode(texts, convert_to_numpy=True, ...)``.
        use_cache: When true, reuse the cache files if they exist and are
            consistent with the tokenizer.
        cache_prefix: Prefix of the two cache file paths.

    Returns:
        ``(index, vocab_tokens, vocab_embs)``: the FAISS index, the list of
        token strings and the normalised embedding matrix.
    """
    emb_cache_path = f"{cache_prefix}_vocab_embs.npy"
    list_cache_path = f"{cache_prefix}_vocab_list.json"

    # Token strings in the same order as vocab_ids.
    vocab_tokens = tokenizer.convert_ids_to_tokens(vocab_ids)

    vocab_embs = None
    if use_cache and os.path.exists(emb_cache_path) and os.path.exists(list_cache_path):
        cached_embs = np.load(emb_cache_path)
        with open(list_cache_path, "r", encoding="utf-8") as f:
            cached_vocab_tokens = json.load(f)
        # Accept the cache only when the token list matches the tokenizer size
        # AND the embedding matrix has one row per cached token; otherwise the
        # FAISS row ids would no longer line up with token ids.
        cache_is_consistent = (
            len(cached_vocab_tokens) == tokenizer.vocab_size
            and cached_embs.ndim == 2
            and cached_embs.shape[0] == len(cached_vocab_tokens)
        )
        if cache_is_consistent:
            vocab_tokens = cached_vocab_tokens
            vocab_embs = cached_embs

    if vocab_embs is None:
        # No usable cache: compute embeddings and refresh both cache files.
        vocab_embs = model.encode(vocab_tokens, convert_to_numpy=True, batch_size=512, show_progress_bar=True)
        np.save(emb_cache_path, vocab_embs)
        with open(list_cache_path, "w", encoding="utf-8") as f:
            json.dump(vocab_tokens, f, ensure_ascii=False, indent=2)

    # FAISS operates on C-contiguous float32 matrices; this is a no-op when
    # the array already has that layout.
    vocab_embs = np.ascontiguousarray(vocab_embs, dtype=np.float32)

    dim = vocab_embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(vocab_embs)  # in place
    index.add(vocab_embs)

    return index, vocab_tokens, vocab_embs

def _construct_hard_negative_token_mix(pos_ids, neg_ids, tokenizer, max_length=128):
    """Splice a chunk of negative token ids into a chunk of positive ones.

    A ratio is drawn uniformly from [0.2, 0.8]; the positive chunk may use at
    most ``int(ratio * max_length)`` ids and the negative chunk the rest of
    ``max_length``. A sequence longer than its budget is cut down to a random
    contiguous window. The negative chunk is then inserted at a random
    position inside the positive chunk.

    Args:
        pos_ids: Positive token ids.
        neg_ids: Negative token ids.
        tokenizer: Object providing ``convert_ids_to_tokens``.
        max_length: Upper bound on the length of the mixed sequence.

    Returns:
        A dict with the mixed ``"tokens"`` and a ``"meta"`` dict holding the
        rounded ratio, the insertion position and the positive/negative
        token chunks that were used.
    """

    def crop(ids, width):
        # Sequences that already fit pass through; longer ones keep only a
        # randomly placed window of `width` ids.
        if len(ids) <= width:
            return ids[:width]
        offset = random.randint(0, len(ids) - width)
        return ids[offset:offset + width]

    ratio = random.uniform(0.2, 0.8)
    pos_budget = int(ratio * max_length)

    # Positive chunk is cropped first, then the negative one, so the random
    # draws happen in a fixed order.
    pos_chunk = crop(pos_ids, pos_budget)
    neg_chunk = crop(neg_ids, max_length - pos_budget)

    split_at = random.randint(0, len(pos_chunk))
    head, tail = pos_chunk[:split_at], pos_chunk[split_at:]
    blended = (head + neg_chunk + tail)[:max_length]

    mixed_tokens = tokenizer.convert_ids_to_tokens(blended)
    pos_tokens = tokenizer.convert_ids_to_tokens(pos_chunk)
    neg_tokens = tokenizer.convert_ids_to_tokens(neg_chunk)

    return {
        "tokens": mixed_tokens,
        "meta": {
            "ratio": round(ratio, 4),
            "insert_pos": split_at,
            "pos_tokens": pos_tokens,
            "neg_tokens": neg_tokens,
        },
    }



def _sample_token_ids(token_ids, min_tokens, max_tokens):
    """Sample between min_tokens and max_tokens ids without replacement.

    Sequences shorter than ``min_tokens`` are returned whole (as a new list)
    instead of raising from an empty ``randint`` range.
    """
    if len(token_ids) < min_tokens:
        return list(token_ids)
    count = random.randint(min_tokens, min(max_tokens, len(token_ids)))
    return random.sample(token_ids, count)


def _pick_foreign_reason_ids(answer, preferred_groups, all_groups, reason_ids_per_item):
    """Return the token ids of a random reason whose answer differs from ``answer``.

    ``preferred_groups`` is tried first and ``all_groups`` is the fallback;
    both map an answer to the item indices carrying it. Returns ``[]`` when
    no item has a different answer.
    """
    for groups in (preferred_groups, all_groups):
        other_answers = [a for a in groups if a != answer]
        if other_answers:
            picked_idx = random.choice(groups[random.choice(other_answers)])
            return reason_ids_per_item[picked_idx]
    return []


def construct_negatives(input_path, output_path, model_name, min_tokens, max_tokens, data_limit, use_cache=False, seed=42, top_k=1000):
    """Attach positive and negative token samples to every item of a dataset.

    The input file is a JSON list of objects with ``"reason"`` and
    ``"answer"`` keys. Each item is extended in place with:

    * ``"pos_token_texts_list"``: four token samples drawn from the item's
      own reason.
    * ``"negatives"``: a dict with
      ``"neg_type_1_tokens"`` (tokens from the reason of a random item with a
      different answer),
      ``"neg_type_2_tokens"`` (nearest vocabulary neighbours of the reason
      embedding, excluding tokens that occur in the reason),
      ``"neg_type_3_tokens"`` (nearest vocabulary neighbours, reason tokens
      not excluded) and
      ``"neg_type_4_tokens"`` (uniformly random vocabulary tokens).

    Tokenisation always uses the ``deepseek-ai/DeepSeek-R1`` tokenizer and the
    vocabulary cache prefix ``"DeepSeek-R1"``; ``model_name`` only selects the
    SentenceTransformer used for embeddings.

    Args:
        input_path: Path of the input JSON file.
        output_path: Path the augmented JSON list is written to.
        model_name: SentenceTransformer model used to embed tokens and reasons.
        min_tokens: Minimum number of tokens per sample (when enough exist).
        max_tokens: Maximum number of tokens per sample.
        data_limit: If positive, process a random subset of at most this many
            items.
        use_cache: Reuse cached vocabulary embeddings when available.
        seed: Seed for all random number generators.
        top_k: Accepted for interface compatibility; currently unused.

    Raises:
        ValueError: If ``min_tokens`` is greater than ``max_tokens``.
    """
    if min_tokens > max_tokens:
        raise ValueError(f"min_tokens ({min_tokens}) must not exceed max_tokens ({max_tokens})")

    set_seed(seed)

    _, model = build_tokenizer(model_name)
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
    # Use the GPU when there is one instead of crashing on CPU-only hosts.
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if data_limit > 0:
        data = random.sample(data, min(data_limit, len(data)))

    # Tokenise every reason exactly once; the ids are reused both for the
    # item's own positives and whenever the item serves as a type-1 negative.
    reason_ids_per_item = [
        tokenizer.encode(item["reason"], add_special_tokens=False) for item in data
    ]

    # answer -> item indices, for all items and for items whose reason is
    # long enough to sample at least min_tokens tokens from.
    answer_to_indices = {}
    long_answer_to_indices = {}
    for idx, item in enumerate(data):
        ans = item["answer"]
        answer_to_indices.setdefault(ans, []).append(idx)
        if len(reason_ids_per_item[idx]) >= min_tokens:
            long_answer_to_indices.setdefault(ans, []).append(idx)

    vocab_ids = load_vocab(tokenizer)  # list(range(tokenizer.vocab_size))
    cache_prefix = "DeepSeek-R1"
    index, vocab_tokens_list, vocab_embs = construct_faiss_index(vocab_ids, tokenizer, model, use_cache, cache_prefix)

    for idx, item in tqdm(enumerate(data), total=len(data)):
        reason = item["reason"]
        answer = item["answer"]
        reason_token_ids = reason_ids_per_item[idx]
        reason_token_id_set = set(reason_token_ids)

        # POSITIVES: four independent samples from the item's own reason.
        item["pos_token_texts_list"] = []
        for _ in range(4):
            sampled_ids = _sample_token_ids(reason_token_ids, min_tokens, max_tokens)
            item["pos_token_texts_list"].append(tokenizer.convert_ids_to_tokens(sampled_ids))

        item["negatives"] = {}

        # NEG TYPE 1: tokens from the reason of an item with another answer.
        # Reasons with at least min_tokens tokens are preferred; a shorter
        # reason is only used (whole) when no long one exists.
        rand_token_ids_type1 = _pick_foreign_reason_ids(
            answer, long_answer_to_indices, answer_to_indices, reason_ids_per_item
        )
        if rand_token_ids_type1:
            neg_type_1_sampled_ids = _sample_token_ids(rand_token_ids_type1, min_tokens, max_tokens)
            item["negatives"]["neg_type_1_tokens"] = tokenizer.convert_ids_to_tokens(neg_type_1_sampled_ids)
        else:
            item["negatives"]["neg_type_1_tokens"] = []

        # NEG TYPE 2 / 3: nearest vocabulary tokens to the reason embedding.
        reason_emb = model.encode(reason, convert_to_numpy=True).reshape(1, -1)
        faiss.normalize_L2(reason_emb)
        # Ask for enough neighbours that max_tokens * 2 remain after dropping
        # every token of the reason, but never more than the index holds.
        k = min(index.ntotal, max_tokens * 2 + len(reason_token_id_set))
        _, neighbor_ids = index.search(reason_emb, k)

        neg_type_2_ids_candidates, neg_type_3_ids_candidates = [], []
        for retrieved_idx in neighbor_ids[0]:
            token_id = int(retrieved_idx)  # plain int instead of np.int64
            if token_id < 0:
                # FAISS pads missing results with -1.
                continue
            if token_id not in reason_token_id_set:
                neg_type_2_ids_candidates.append(token_id)
            neg_type_3_ids_candidates.append(token_id)

            # Type 3 is always at least as long as type 2, so checking type 2
            # is enough to know both candidate lists are full.
            if len(neg_type_2_ids_candidates) >= max_tokens * 2:
                break

        # NEG TYPE 4: uniformly random vocabulary tokens.
        num_4 = min(len(vocab_ids), random.randint(min_tokens, max_tokens))
        neg_type_4_sampled_ids = random.sample(vocab_ids, num_4)
        item["negatives"]["neg_type_4_tokens"] = tokenizer.convert_ids_to_tokens(neg_type_4_sampled_ids)

        # NEG TYPE 2: similar tokens that do not occur in the reason.
        if neg_type_2_ids_candidates:
            num_2 = min(len(neg_type_2_ids_candidates), random.randint(min_tokens, max_tokens))
            neg_type_2_sampled_ids = random.sample(neg_type_2_ids_candidates, num_2)
            item["negatives"]["neg_type_2_tokens"] = tokenizer.convert_ids_to_tokens(neg_type_2_sampled_ids)
        else:
            item["negatives"]["neg_type_2_tokens"] = []

        # NEG TYPE 3: similar tokens, reason tokens included.
        if neg_type_3_ids_candidates:
            num_3 = min(len(neg_type_3_ids_candidates), random.randint(min_tokens, max_tokens))
            neg_type_3_sampled_ids = random.sample(neg_type_3_ids_candidates, num_3)
            item["negatives"]["neg_type_3_tokens"] = tokenizer.convert_ids_to_tokens(neg_type_3_sampled_ids)
        else:
            item["negatives"]["neg_type_3_tokens"] = []

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"{output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the options below and forward them
    # unchanged to construct_negatives.
    cli = argparse.ArgumentParser(description="Construct positive and negative sample data")

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        ("--input_path", dict(type=str, required=True, help="Path to the input JSON file")),
        ("--output_path", dict(type=str, required=True, help="Path to the output JSON file")),
        ("--model_name", dict(type=str, default="all-MiniLM-L6-v2", help="Name of the SentenceTransformer model")),
        ("--min_tokens", dict(type=int, default=5, help="Minimum number of tokens")),
        ("--max_tokens", dict(type=int, default=20, help="Maximum number of tokens")),
        ("--use_cache", dict(action="store_true", help="Whether to use local cached vocab embeddings")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        ("--top_k", dict(type=int, default=1000, help="FAISS top-k retrieval (reserved, currently unused—actual value is max_tokens * 2)")),
        ("--data_limit", dict(type=int, default=-1, help="Maximum number of data entries to use; -1 means no limit")),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    construct_negatives(
        input_path=opts.input_path,
        output_path=opts.output_path,
        model_name=opts.model_name,
        min_tokens=opts.min_tokens,
        max_tokens=opts.max_tokens,
        data_limit=opts.data_limit,
        use_cache=opts.use_cache,
        seed=opts.seed,
        top_k=opts.top_k,
    )